{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "063850ab-22b0-4838-b53a-9bb11757d9d0",
   "metadata": {},
   "source": [
    "# Embedding层和 Linear层"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0315c598-701f-46ff-8806-15813cad0e51",
   "metadata": {},
   "source": [
    "- 在PyTorch中，嵌入层（Embedding layers）实现了执行矩阵乘法的线性层的相同功能；我们使用嵌入层的原因是为了提高计算效率。\n",
    "- 我们将逐步使用PyTorch中的代码示例来查看这种关系。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "061720f4-f025-4640-82a0-15098fa94cf9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "PyTorch version: 1.12.1+cu113\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "print(\"PyTorch version:\", torch.__version__)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7895a66-7f69-4f62-9361-5c9da2eb76ef",
   "metadata": {},
   "source": [
    "## Using nn.Embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "cc489ea5-73db-40b9-959e-0d70cae25f40",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 假设我们有以下 3 个训练样本，\n",
    "# 这些样本可能表示语言模型（LM）上下文中的标记ID\n",
    "idx = torch.tensor([2, 3, 1])\n",
    "\n",
    "# 嵌入矩阵的行数可以通过获取最大标记ID + 1 来确定。\n",
    "# 如果最高的标记ID是3，则我们希望有4行，对应可能的\n",
    "# 标记ID 0, 1, 2, 3\n",
    "num_idx = max(idx) + 1\n",
    "\n",
    "# 所需的嵌入维度是一个超参数\n",
    "out_dim = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "93d83a6e-8543-40af-b253-fe647640bf36",
   "metadata": {},
   "source": [
    "- 实现一个简单的嵌入层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "60a7c104-36e1-4b28-bd02-a24a1099dc66",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 为了可重复性，我们使用随机种子，\n",
    "# 因为嵌入层的权重是用小的随机值初始化的\n",
    "torch.manual_seed(123)\n",
    "\n",
    "# 创建一个嵌入层，指定输入维度为 num_idx，输出维度为 out_dim\n",
    "embedding = torch.nn.Embedding(num_idx, out_dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd96c00a-3297-4a50-8bfc-247aaea7e761",
   "metadata": {},
   "source": [
    "查看嵌入权重数据情况"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "595f603e-8d2a-4171-8f94-eac8106b2e57",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[ 0.3374, -0.1778, -0.3035, -0.5880,  1.5810],\n",
       "        [ 1.3010,  1.2753, -0.2010, -0.1606, -0.4015],\n",
       "        [ 0.6957, -1.8061, -1.1589,  0.3255, -0.6315],\n",
       "        [-2.8400, -0.7849, -1.4096, -0.4076,  0.7953]], requires_grad=True)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding.weight"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c86eb562-61e2-4171-ab6e-b410a1fd5c18",
   "metadata": {},
   "source": [
    "- 使用嵌入层来获取具有ID 1的训练样本的向量表示"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8bbc0255-4805-4be9-9f4c-1d0d967ef9d5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 1.3010,  1.2753, -0.2010, -0.1606, -0.4015]],\n",
       "       grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding(torch.tensor([1]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a4d47f2-4691-47b8-9855-2528b6c285c9",
   "metadata": {},
   "source": [
    "- 下面是底层操作的可视化"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12ffd155-7190-44b1-b6b6-45b11d6fe83b",
   "metadata": {},
   "source": [
    "<img src=\"images/1.png\" width=\"400px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87d1311b-cfb2-4afc-9e25-e4ecf35370d9",
   "metadata": {},
   "source": [
    "- 同样，我们可以使用嵌入层来获取具有ID 2的训练样本的向量表示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c309266a-c601-4633-9404-2e10b1cdde8c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.6957, -1.8061, -1.1589,  0.3255, -0.6315]],\n",
       "       grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding(torch.tensor([2]))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ad3b601-f68c-41b1-a28d-b624b94ef383",
   "metadata": {},
   "source": [
    "<img src=\"images/2.png\" width=\"400px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "27dd54bd-85ae-4887-9c5e-3139da361cf4",
   "metadata": {},
   "source": [
    "- 现在，让我们将之前定义的所有训练样本转换："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0191aa4b-f6a8-4b0d-9c36-65e82b81d071",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.6957, -1.8061, -1.1589,  0.3255, -0.6315],\n",
       "        [-2.8400, -0.7849, -1.4096, -0.4076,  0.7953],\n",
       "        [ 1.3010,  1.2753, -0.2010, -0.1606, -0.4015]],\n",
       "       grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 将原先的第三行变成现在的第一行，第四行变成现在的第二行，第二行变成现在的第三行\n",
    "idx = torch.tensor([2, 3, 1])\n",
    "embedding(idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "146cf8eb-c517-4cd4-aa91-0e818fed7651",
   "metadata": {},
   "source": [
    "- Under the hood, it's still the same look-up concept:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b392eb43-0bda-4821-b446-b8dcbee8ae00",
   "metadata": {},
   "source": [
    "<img src=\"images/3.png\" width=\"450px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f0fe863b-d6a3-48f3-ace5-09ecd0eb7b59",
   "metadata": {},
   "source": [
    "## 使用 nn.Linear"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "138de6a4-2689-4c1f-96af-7899b2d82a4e",
   "metadata": {},
   "source": [
    "- 接下来，我们将使用One-Hot编码，与embedding 层一样，在 `nn.Linear` 层进行操作\n",
    "- 首先，我们将标记ID转换为One-Hot表示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b5bb56cf-bc73-41ab-b107-91a43f77bdba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0, 1, 0],\n",
       "        [0, 0, 0, 1],\n",
       "        [0, 1, 0, 0]])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "onehot = torch.nn.functional.one_hot(idx)\n",
    "onehot"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa45dfdf-fb26-4514-a176-75224f5f179b",
   "metadata": {},
   "source": [
    "- 接下来，我们使用矩阵乘法$X W^\\top$ 来初始化一个Linear层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ae04c1ed-242e-4dd7-b8f7-4b7e4caae383",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter containing:\n",
      "tensor([[-0.2039,  0.0166, -0.2483,  0.1886],\n",
      "        [-0.4260,  0.3665, -0.3634, -0.3975],\n",
      "        [-0.3159,  0.2264, -0.1847,  0.1871],\n",
      "        [-0.4244, -0.3034, -0.1836, -0.0983],\n",
      "        [-0.3814,  0.3274, -0.1179,  0.1605]], requires_grad=True)\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "# 初始化一个Linear层，该层的权重矩阵是由 num_idx（输入维度）到 out_dim（输出维度）的一个线性层，而且没有偏置项\n",
    "linear = torch.nn.Linear(num_idx, out_dim, bias=False)\n",
    "print(linear.weight)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63efb98e-5cc4-4e8d-9fe6-ef0ad29ae2d7",
   "metadata": {},
   "source": [
    "- 请注意，PyTorch中的`linear`层也是用小的随机权重进行初始化的。为了与上面的 `Embedding` 层进行直接比较，我们必须使用相同的小随机权重，这就是我们在这里重新分配它们的原因："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "a3b90d69-761c-486e-bd19-b38a2988fe62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# linear 层的权重就被重新赋值为与 embedding 层相同的小随机权重，以确保它们具有相同的初始化。这是为了使它们在后续操作中可以进行直接比较。\n",
    "linear.weight = torch.nn.Parameter(embedding.weight.T.detach())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9116482d-f1f9-45e2-9bf3-7ef5e9003898",
   "metadata": {},
   "source": [
    "- 现在，我们可以使用线性层处理输入的One-Hot编码表示："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "90d2b0dd-9f1d-4c0f-bb16-1f6ce6b8ac2c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.6957, -1.8061, -1.1589,  0.3255, -0.6315],\n",
       "        [-2.8400, -0.7849, -1.4096, -0.4076,  0.7953],\n",
       "        [ 1.3010,  1.2753, -0.2010, -0.1606, -0.4015]], grad_fn=<MmBackward0>)"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "linear(onehot.float())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f6204bc8-92e2-4546-9cda-574fe1360fa2",
   "metadata": {},
   "source": [
    "正如我们所看到的，这与我们使用嵌入层时得到的结果完全相同："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "2b057649-3176-4a54-b58c-fd8fbf818c61",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.6957, -1.8061, -1.1589,  0.3255, -0.6315],\n",
       "        [-2.8400, -0.7849, -1.4096, -0.4076,  0.7953],\n",
       "        [ 1.3010,  1.2753, -0.2010, -0.1606, -0.4015]],\n",
       "       grad_fn=<EmbeddingBackward0>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding(idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e447639-8952-460e-8c8f-cf9e23c368c9",
   "metadata": {},
   "source": [
    "- 底层发生的计算如下，针对第一个训练样本的标记ID："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1830eccf-a707-4753-a24a-9b103f55594a",
   "metadata": {},
   "source": [
    "<img src=\"images/4.png\" width=\"450px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ce5211a-14e6-46aa-a3a8-14609f086e97",
   "metadata": {},
   "source": [
    "- 以及对于第二个训练样本的标记ID："
   ]
  },
  {
   "cell_type": "markdown",
   "id": "173f6026-a461-44da-b9c5-f571f8ec8bf3",
   "metadata": {},
   "source": [
    "<img src=\"images/5.png\" width=\"450px\">"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2608049-f5d1-49a9-a14b-82695fc32e6a",
   "metadata": {},
   "source": [
    "- \n",
    "由于每个独热编码行中除了一个索引外都为0（设计如此），这个矩阵乘法本质上就是对独热编码元素的查找\n",
    "- 。在独热编码上使用矩阵乘法与使用嵌入层查找是等效的，但如果我们使用大型嵌入矩阵，这种方法可能效率较低，因为有很多不必要的零乘法。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5eacc005-86fc-490c-8f6a-dc37d8a0df7c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
